"""
简单的RNN使用示范
"""

import torch
from torch.nn import RNN

# 定义RNN执行的循环数
seq_size = 3
# 定义一批的样本数
batch_size = 1
# 定义输入的特征数
input_size = 4
# 定义输出的特征数
hidden_size = 2
# 定义有多少层RNN
num_layers = 1

# 输入的参数
input_data = torch.randn(seq_size, batch_size, input_size)
# 隐层的数据
hidden_data = torch.randn(num_layers, batch_size, hidden_size)

# 定义RNN
myRnn = RNN(input_size, hidden_size, num_layers)

# 计算
# output(seq_size,batch_size,hidden_size)
# hidden_out(num_layers, batch_size, hidden_size)
output, hidden_out = myRnn(input_data, hidden_data)
print(f"output shape: {output.shape}, output: {output}")
print(f"hidden shape: {hidden_out.shape}, hidden_out: {hidden_out}")
